"""Select images on which a panel of pretrained classifiers all agree.

Every image in ``input_images_dir`` is classified by each model in
``models_dict``. Images for which every model predicts the same class index
are saved into ``output_images_dir``. A CSV of per-image predictions (plus
error rows for images that failed) is written into ``output_dataset_dir``.
"""
import os
# The environment is configured ahead of the torch/timm imports below,
# presumably so those libraries see these values when they resolve cache
# locations and download endpoints — confirm for the library versions in use.
os.environ["TORCH_HOME"] = "/cache"  # assumed: torch.hub cache dir for pretrained weights
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"  # assumed: Hugging Face mirror used by timm downloads
# NOTE(review): no visible code relies on relative paths; this chdir presumably
# lets repository-relative resources resolve — confirm it is still needed.
os.chdir('/TransferAttack-main')
from PIL import Image
import torch
import torchvision.models as models
import timm  
import torchvision.transforms as transforms
import pandas as pd

# Filesystem layout: images are read from `input_images_dir`, the consensus
# subset is written to `output_images_dir`, and the results CSV goes into
# `output_dataset_dir` (image paths in the CSV are stored relative to it).
_subset_root = '/TransferAttack-main/coco/coco_subset_classified'
input_images_dir = os.path.join(_subset_root, 'images')
output_images_dir = os.path.join(_subset_root, 'consensus_images')
output_dataset_dir = '/TransferAttack-main/coco/'
image_size = (224, 224)  # target image resolution

# Create the output folder up front; does nothing if it already exists.
os.makedirs(output_images_dir, exist_ok=True)


def load_model(model_name, device=None):
    """Build a pretrained classifier by name, in eval mode, on *device*.

    Args:
        model_name: One of the torchvision architectures ('resnet18',
            'resnet101', 'resnext50_32x4d', 'densenet121') or timm
            architectures ('vit_base_patch16_224', 'pit_b_224',
            'visformer_small', 'swin_tiny_patch4_window7_224').
        device: Device to move the model to. Defaults to CUDA when
            available, otherwise CPU.

    Returns:
        The model with pretrained weights, switched to eval mode and moved
        to *device*.

    Raises:
        ValueError: If *model_name* is not one of the supported names.
    """
    torchvision_names = ('resnet18', 'resnet101', 'resnext50_32x4d', 'densenet121')
    timm_names = (
        'vit_base_patch16_224',
        'pit_b_224',
        'visformer_small',
        'swin_tiny_patch4_window7_224',
    )

    if model_name in torchvision_names:
        # torchvision exposes each architecture as a same-named builder function.
        model = getattr(models, model_name)(pretrained=True)
    elif model_name in timm_names:
        model = timm.create_model(model_name, pretrained=True)
    else:
        raise ValueError(f"Unknown model: {model_name}")

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    model.to(device)
    return model


# The full panel of classifiers whose predictions must agree, loaded once up
# front in this order and keyed by the name passed to load_model.
models_dict = {
    name: load_model(name)
    for name in (
        'resnet18',
        'resnet101',
        'resnext50_32x4d',
        'densenet121',
        'vit_base_patch16_224',
        'pit_b_224',
        'visformer_small',
        'swin_tiny_patch4_window7_224',
    )
}


# Resize first, so the pixels that get classified are exactly the pixels that
# get saved for consensus images.
resize = transforms.Resize(image_size)

# PIL RGB image -> float CHW tensor in [0, 1]. Normalization is applied per
# model afterwards (see _normalizer_for), not here.
preprocess = transforms.ToTensor()

# Fallback statistics for models that do not publish their own (the
# torchvision models); these are the values the script always used.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _normalizer_for(model):
    """Return a ``Normalize`` transform matching *model*'s expected input stats.

    timm models carry a ``default_cfg`` mapping with their own ``mean``/``std``
    (some checkpoints, e.g. certain ViTs, use 0.5/0.5 instead of the ImageNet
    values). torchvision models have no such attribute and fall back to
    ``IMAGENET_MEAN``/``IMAGENET_STD``.
    """
    cfg = getattr(model, 'default_cfg', None) or {}
    return transforms.Normalize(mean=cfg.get('mean', IMAGENET_MEAN),
                                std=cfg.get('std', IMAGENET_STD))


# Per-model normalization and device, computed once rather than per image.
normalizers = {name: _normalizer_for(m) for name, m in models_dict.items()}
model_devices = {name: next(m.parameters()).device for name, m in models_dict.items()}


def _predict_all(img):
    """Return ``{model_name: top-1 class index}`` for a PIL RGB image."""
    img_t = preprocess(img)
    predictions = {}
    with torch.no_grad():
        for model_name, model in models_dict.items():
            batch_t = normalizers[model_name](img_t).unsqueeze(0)
            out = model(batch_t.to(model_devices[model_name]))
            predictions[model_name] = out.argmax(dim=1).item()
    return predictions


# One "<model>_pred" column per model, shared by success and error rows so
# the CSV schema is consistent.
pred_columns = [f'{model_name}_pred' for model_name in models_dict]

results = []

# Sorted so the CSV row order is reproducible across runs and filesystems.
for image_name in sorted(os.listdir(input_images_dir)):
    if not image_name.lower().endswith(('.png', '.jpg', '.jpeg')):
        continue

    image_path = os.path.join(input_images_dir, image_name)

    try:
        with Image.open(image_path) as raw:
            img = resize(raw.convert('RGB'))

        predictions = _predict_all(img)

        # Keep the image only if every model predicts the same class.
        if len(set(predictions.values())) != 1:
            continue
        predicted_class_index = next(iter(predictions.values()))

        new_image_path = os.path.join(output_images_dir, image_name)
        img.save(new_image_path)

        result = {
            'original_filename': image_name,
            'new_image_path': os.path.relpath(new_image_path, output_dataset_dir),
            'predicted_class_index': predicted_class_index,
        }
        result.update({f'{name}_pred': pred for name, pred in predictions.items()})
        results.append(result)

    except Exception as e:
        # Best effort: keep processing remaining images, but report the failure
        # and record it as an error row with -1 for every prediction column.
        print(f"Failed to process {image_name}: {e}")
        error_row = {
            'original_filename': image_name,
            'new_image_path': 'ERROR',
            'predicted_class_index': -1,
            'error_message': str(e),
        }
        error_row.update(dict.fromkeys(pred_columns, -1))
        results.append(error_row)


# Fixed column order; 'error_message' is empty (NaN) for successful rows.
results_df = pd.DataFrame(
    results,
    columns=['original_filename', 'new_image_path', 'predicted_class_index',
             *pred_columns, 'error_message'],
)
csv_output_path = os.path.join(output_dataset_dir, 'consensus_classification_results.csv')
results_df.to_csv(csv_output_path, index=False)


